{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multicomponent models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.message_passing import MulticomponentMessagePassing\n",
    "from chemprop.models import MulticomponentMPNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Overview"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The basic Chemprop model takes a single molecule or reaction as input. A multicomponent Chemprop model combines these building blocks to take multiple molecules and/or reactions as input. This is useful for properties that depend on more than one component, such as properties of a solute in a solvent."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Message passing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MulticomponentMessagePassing` organizes the single-component [message passing](./message_passing.ipynb) modules, one for each component in the multicomponent dataset. The individual message passing modules can be unique to each component, shared between some components, or shared between all components. If all components share the same message passing module, pass a single block and set the `shared` flag to `True`. Note that components that use different featurizers (e.g. molecules and reactions) should not share a message passing module, since their featurized inputs typically differ in dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn import BondMessagePassing\n",
    "\n",
    "# Unique message passing block for each component\n",
    "mp1 = BondMessagePassing(d_h=100)\n",
    "mp2 = BondMessagePassing(d_h=600)\n",
    "blocks = [mp1, mp2]\n",
    "mcmp = MulticomponentMessagePassing(blocks=blocks, n_components=len(blocks))\n",
    "\n",
    "# One block shared by all components (this overwrites the `mcmp` defined above)\n",
    "mp = BondMessagePassing()\n",
    "mcmp = MulticomponentMessagePassing(blocks=[mp], n_components=2, shared=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "During the forward pass of the model, the output of each message passing block is aggregated, and the aggregated outputs are concatenated to form the input to the predictor."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Aggregation\n",
    "\n",
    "A single [aggregation](./aggregation.ipynb) module is applied to the output of every message passing block."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn import MeanAggregation\n",
    "\n",
    "agg = MeanAggregation()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predictor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The [predictor](./predictor.ipynb) must be given the output dimension of the multicomponent message passing module as its input dimension. This value is available as `mcmp.output_dim`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn import RegressionFFN\n",
    "\n",
    "ffn = RegressionFFN(input_dim=mcmp.output_dim)"
   ]
  },
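  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the dimension bookkeeping concrete: since the aggregated outputs of the message passing blocks are concatenated (as described above), the predictor's input width should be the sum of the blocks' output widths. This is a sketch of the arithmetic, not a statement about the implementation. The symbols $d_o^{(i)}$ (output dimension of the block for component $i$) and $n$ (number of components) are illustrative and not part of the Chemprop API.\n",
    "\n",
    "$$d_{\\text{in}} = \\sum_{i=1}^{n} d_o^{(i)}$$\n",
    "\n",
    "For the two components sharing a default `BondMessagePassing` block above (hidden size 300, assuming no additional atom descriptors), this gives $300 + 300 = 600$. The two unique blocks with `d_h=100` and `d_h=600` would instead give $100 + 600 = 700$."
   ]
  },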
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multicomponent MPNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The submodules are combined into a `MulticomponentMPNN` model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MulticomponentMPNN(\n",
       "  (message_passing): MulticomponentMessagePassing(\n",
       "    (blocks): ModuleList(\n",
       "      (0-1): 2 x BondMessagePassing(\n",
       "        (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
       "        (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "        (W_o): Linear(in_features=372, out_features=300, bias=True)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "        (tau): ReLU()\n",
       "        (V_d_transform): Identity()\n",
       "        (graph_transform): Identity()\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (agg): MeanAggregation()\n",
       "  (bn): Identity()\n",
       "  (predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=600, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): Identity()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0-1): 2 x MSE(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mc_model = MulticomponentMPNN(mcmp, agg, ffn)\n",
    "mc_model"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
